package org.hepeng.workx.spring.cloud.netflix.ribbon.loadbalancer;

import com.netflix.client.config.IClientConfig;
import com.netflix.loadbalancer.AbstractLoadBalancerRule;
import com.netflix.loadbalancer.ILoadBalancer;
import com.netflix.loadbalancer.IRule;
import com.netflix.loadbalancer.Server;
import com.netflix.loadbalancer.ZoneAvoidanceRule;
import com.netflix.zuul.context.RequestContext;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.hepeng.workx.spring.cloud.netflix.ribbon.RibbonRequestContext;
import org.hepeng.workx.spring.context.ApplicationContextHolder;
import org.joor.Reflect;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cloud.netflix.zuul.filters.route.RibbonRoutingFilter;
import org.springframework.util.Assert;
import org.springframework.util.MultiValueMap;

import javax.servlet.http.HttpServletRequest;
import java.util.List;
import java.util.Objects;


/**
 * Base {@link IRule} that narrows the candidate server list for a single {@link #choose(Object)}
 * call and then delegates the actual selection to a wrapped rule ({@link ZoneAvoidanceRule} by
 * default).
 *
 * <p>The narrowing is triggered by a request parameter or header named {@link #paramName}
 * ({@code "load-balance"} by default). Subclasses decide which servers participate via
 * {@link #getParticipateLoadBalancingServers(Object, RibbonRequestContext, List)}.
 *
 * <p>The load balancer passed to {@link #setLoadBalancer(ILoadBalancer)} is wrapped in a
 * {@link ThreadIsolationLoadBalancer}. The narrowed list is installed on that wrapper and
 * cleared again after every selection. NOTE(review): isolation between concurrent requests
 * relies on {@code ThreadIsolationLoadBalancer} keeping the installed list per thread — confirm.
 *
 * @author he peng
 */
public abstract class AbstractThreadIsolationRule extends AbstractLoadBalancerRule {

    private static final Logger LOG = LoggerFactory.getLogger(AbstractThreadIsolationRule.class);

    /** Rule that performs the final server selection; it is given the wrapped load balancer. */
    protected IRule rule = new ZoneAvoidanceRule();

    /** Name of the request parameter / header whose presence triggers server narrowing. */
    protected String paramName;

    /** Wrapper around the configured load balancer; {@code null} until {@link #setLoadBalancer} runs. */
    protected ThreadIsolationLoadBalancer threadIsolationLoadBalancer;

    /** Creates a rule keyed on the {@code "load-balance"} parameter/header. */
    public AbstractThreadIsolationRule() {
        this.paramName = "load-balance";
    }

    /**
     * Creates a rule keyed on a custom parameter/header name.
     *
     * @param paramName name of the parameter/header; must contain non-whitespace text
     * @throws IllegalArgumentException if {@code paramName} is null, empty or whitespace only
     */
    public AbstractThreadIsolationRule(String paramName) {
        Assert.hasText(paramName, "paramName must not be blank string");
        this.paramName = paramName;
    }

    /**
     * Creates a rule keyed on a custom parameter/header name with a custom delegate rule.
     *
     * @param paramName name of the parameter/header; must contain non-whitespace text
     * @param rule      rule used for the final server selection
     * @throws IllegalArgumentException if {@code paramName} is blank or {@code rule} is null
     */
    public AbstractThreadIsolationRule(String paramName, IRule rule) {
        this(paramName);
        Assert.notNull(rule, "rule must not be null");
        this.rule = rule;
    }

    /**
     * Wraps {@code lb} in a {@link ThreadIsolationLoadBalancer} and installs the wrapper on both
     * this rule and the delegate rule, so both see any server list set by {@link #resetAllServers}.
     */
    @Override
    public void setLoadBalancer(ILoadBalancer lb) {
        this.threadIsolationLoadBalancer = new ThreadIsolationLoadBalancer(lb);
        super.setLoadBalancer(threadIsolationLoadBalancer);
        this.rule.setLoadBalancer(threadIsolationLoadBalancer);
    }

    /** Returns the wrapped load balancer installed by {@link #setLoadBalancer(ILoadBalancer)}. */
    @Override
    public ILoadBalancer getLoadBalancer() {
        return super.getLoadBalancer();
    }

    /**
     * Forwards the client configuration to the delegate rule when it can accept it.
     *
     * <p>Only this outer rule is configured by the framework; without forwarding, the delegate
     * (created here or passed to the constructor) would never receive the client configuration.
     */
    @Override
    public void initWithNiwsConfig(IClientConfig clientConfig) {
        if (this.rule instanceof AbstractLoadBalancerRule) {
            ((AbstractLoadBalancerRule) this.rule).initWithNiwsConfig(clientConfig);
        }
    }

    /**
     * Chooses a server for the current request.
     *
     * <p>If no Ribbon request context can be built, the delegate rule chooses from the unmodified
     * server list. Otherwise, when {@link #isSupported} accepts the request, the subclass-provided
     * subset of servers (if non-empty) replaces the server list for this selection only; the
     * wrapper is cleared in a {@code finally} block afterwards.
     *
     * @param key load-balancer key, passed through unchanged to the delegate rule
     * @return the server chosen by the delegate rule (may be {@code null} if it finds none)
     */
    @Override
    public Server choose(Object key) {
        RibbonRequestContext ribbonContext = getRibbonRequestContext();
        if (Objects.isNull(ribbonContext)) {
            return this.rule.choose(key);
        }

        ILoadBalancer lb = getLoadBalancer();
        List<Server> originalAllServers = lb.getAllServers();
        // Per-request, potentially large output: keep it at DEBUG to avoid flooding gateway logs.
        LOG.debug("Original All Servers -> {} ", originalAllServers);
        try {
            if (isSupported(key, ribbonContext)) {
                List<Server> loadBalancingServers =
                        getParticipateLoadBalancingServers(key, ribbonContext, originalAllServers);
                // An empty/null subset falls back to the full server list rather than choosing nothing.
                if (CollectionUtils.isNotEmpty(loadBalancingServers)) {
                    resetAllServers(loadBalancingServers);
                }
                LOG.debug("LoadBalancing All Servers -> {} ", loadBalancingServers);
            }

            Server server = this.rule.choose(key);
            LOG.debug("{} Choose Server -> {} ", this.rule.getClass().getName(), server);
            return server;
        } finally {
            // Always drop the narrowed list so it cannot leak into a later selection.
            this.threadIsolationLoadBalancer.clear();
        }
    }

    /**
     * Installs {@code allServers} as the server list seen by the delegate rule for the current
     * selection.
     */
    protected void resetAllServers(List<Server> allServers) {
        this.threadIsolationLoadBalancer.setAllServerList(allServers);
    }

    /**
     * Builds a {@link RibbonRequestContext} for the current Zuul request.
     *
     * <p>It reflectively invokes {@code RibbonRoutingFilter#buildCommandContext} on the filter bean
     * and copies the result. This is best effort: when there is no current servlet request, or any
     * step fails (missing bean or application context, reflection error, copy failure), it returns
     * {@code null} and the caller falls back to plain delegation.
     *
     * @return the copied context, or {@code null} if none could be built
     */
    protected RibbonRequestContext getRibbonRequestContext() {
        try {
            RequestContext currentContext = RequestContext.getCurrentContext();
            HttpServletRequest request = currentContext.getRequest();
            if (Objects.isNull(request)) {
                return null;
            }
            RibbonRoutingFilter ribbonRoutingFilter = ApplicationContextHolder.getApplicationContext()
                    .getBean(RibbonRoutingFilter.class);
            Object commandContext = Reflect.on(ribbonRoutingFilter)
                    .call("buildCommandContext", currentContext)
                    .get();
            return Objects.nonNull(commandContext) ? RibbonRequestContext.copy(commandContext) : null;
        } catch (Exception | LinkageError e) {
            // Swallow recoverable failures (incl. class-loading problems) but let VM errors propagate.
            LOG.error("create {} error", RibbonRequestContext.class, e);
            return null;
        }
    }

    /**
     * Returns the servers that should take part in load balancing for this request.
     *
     * @param key                load-balancer key passed to {@link #choose(Object)}
     * @param ribbonContext      context of the current request
     * @param originalAllServers full server list of the underlying load balancer
     * @return the participating servers; a {@code null} or empty result keeps the full list
     */
    public abstract List<Server> getParticipateLoadBalancingServers(Object key, RibbonRequestContext ribbonContext, List<Server> originalAllServers);

    /**
     * Returns whether the request carries {@link #paramName} as a query parameter or a header.
     *
     * @param key           load-balancer key (unused by this default implementation)
     * @param ribbonContext context of the current request
     */
    protected boolean isSupported(Object key, RibbonRequestContext ribbonContext) {
        MultiValueMap<String, String> params = ribbonContext.getParams();
        boolean paramSupported = Objects.nonNull(params) && params.containsKey(paramName);
        MultiValueMap<String, String> headers = ribbonContext.getHeaders();
        boolean headerSupported = Objects.nonNull(headers) && headers.containsKey(paramName);
        return paramSupported || headerSupported;
    }

    /**
     * Returns the first value of {@link #paramName}, looking at query parameters first and
     * falling back to headers when the parameters provide no value.
     *
     * @param ribbonContext context of the current request
     * @return the value, or {@code null} if neither parameters nor headers contain one
     */
    protected String getDecideLoadBalanceParam(RibbonRequestContext ribbonContext) {
        String decideLoadBalanceParam = null;
        MultiValueMap<String, String> params = ribbonContext.getParams();
        if (MapUtils.isNotEmpty(params)) {
            decideLoadBalanceParam = params.getFirst(paramName);
        }
        // Fall back to headers even when unrelated query params exist, mirroring isSupported().
        if (Objects.isNull(decideLoadBalanceParam)) {
            MultiValueMap<String, String> headers = ribbonContext.getHeaders();
            if (MapUtils.isNotEmpty(headers)) {
                decideLoadBalanceParam = headers.getFirst(paramName);
            }
        }
        return decideLoadBalanceParam;
    }
}
